{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Shor's algorithm in Cirq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook presents a pedagogical demonstration of Shor's algorithm in Cirq. This tutorial is a modified and expanded version of [this Cirq example](https://github.com/quantumlib/Cirq/blob/master/examples/shor.py)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Imports for the notebook.\"\"\"\n",
    "import fractions\n",
    "import math\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import sympy\n",
    "from typing import Callable, List, Optional, Sequence, Union\n",
    "\n",
    "import cirq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Order finding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Factoring an integer $n$ can be reduced to finding the period of the <i>modular exponential function</i> (defined below). This period is the <i>order</i> of a randomly chosen element of the multiplicative group modulo $n$, and with high probability knowing it lets us find a nontrivial factor of $n$.\n",
    "\n",
    "Let $n$ be a positive integer and \n",
    "\n",
    "\\begin{equation}\n",
    "\\mathbb{Z}_n := \\{x \\in \\mathbb{Z}_+ : x < n \\text{ and } \\text{gcd}(x, n) = 1\\}\n",
    "\\end{equation}\n",
    "\n",
    "be the multiplicative group modulo $n$.\n",
    "The <i>order</i> of $x \\in \\mathbb{Z}_n$ is the smallest positive integer $r$ such that $x^r \\text{ mod } n = 1$. <i>Order finding</i> is the task of computing $r$ given $x$ and $n$.\n",
    "\n",
    "It can be shown from group/number theory that:\n",
    "\n",
    "(1) Such an integer $r$ exists. (Note that $g^{|G|} = 1_G$ for any finite group $G$ with cardinality $|G|$ and any element $g \\in G$. It is possible that $r < |G|$, but by Lagrange's theorem $r$ always divides $|G|$.)\n",
    "\n",
    "(2) If $n = pq$ for distinct primes $p$ and $q$, then $|\\mathbb{Z}_n| = \\phi(n) = (p - 1) (q - 1)$. (The function $\\phi$ is called [Euler's totient function](https://en.wikipedia.org/wiki/Euler%27s_totient_function).)\n",
    "\n",
    "(3) The modular exponential function\n",
    "\n",
    "\\begin{equation}\n",
    "f_x(z) := x^z \\mod n\n",
    "\\end{equation}\n",
    "\n",
    "is periodic with period $r$ (the order of the element $x \\in \\mathbb{Z}_n$). That is, $f_x(z + r) = f_x(z)$ for every integer $z \\ge 0$, and $r$ is the smallest positive integer with this property.\n",
    "\n",
    "(4) If we know the period of the modular exponential function, we can (with high probability) figure out $p$ and $q$ -- that is, factor $n$."
   ]
  },
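  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make facts (3) and (4) concrete, the classical sketch below (not part of the original Cirq example) assumes the order $r$ is already known and applies the standard reduction: if $r$ is even and $x^{r/2} \\not\\equiv -1 \\pmod{n}$, then $\\gcd(x^{r/2} - 1, n)$ and $\\gcd(x^{r/2} + 1, n)$ are nontrivial factors of $n$. The helper name `factors_from_order` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Sketch: from the order of x to factors of n (hypothetical helper).\"\"\"\n",
    "import math\n",
    "from typing import Optional, Tuple\n",
    "\n",
    "\n",
    "def factors_from_order(x: int, r: int, n: int) -> Optional[Tuple[int, int]]:\n",
    "    # Assumes r is the order of x modulo n\n",
    "    if r % 2 == 1:\n",
    "        return None  # Odd order: r / 2 is not an integer\n",
    "    y = pow(x, r // 2, n)\n",
    "    if y == n - 1:\n",
    "        return None  # x^(r/2) = -1 mod n gives only trivial gcds\n",
    "    return math.gcd(y - 1, n), math.gcd(y + 1, n)\n",
    "\n",
    "\n",
    "# f_x(z) = 7^z mod 15 repeats with period r = 4\n",
    "print([pow(7, z, 15) for z in range(8)])  # [1, 7, 4, 13, 1, 7, 4, 13]\n",
    "print(factors_from_order(7, 4, 15))  # (3, 5)"
   ]
  },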
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a refresher, we can list the elements of the multiplicative group $\\mathbb{Z}_n$ for a given integer $n$ with the following simple function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Function to compute the elements of Z_n.\"\"\"\n",
    "def multiplicative_group(n: int) -> List[int]:\n",
    "    \"\"\"Returns the multiplicative group modulo n.\n",
    "    \n",
    "    Args:\n",
    "        n: Modulus of the multiplicative group.\n",
    "    \"\"\"\n",
    "    assert n > 1\n",
    "    group = [1]\n",
    "    for x in range(2, n):\n",
    "        if math.gcd(x, n) == 1:\n",
    "            group.append(x)\n",
    "    return group"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, the multiplicative group modulo $n = 15$ is shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The multiplicative group modulo n = 15 is:\n",
      "[1, 2, 4, 7, 8, 11, 13, 14]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Example of a multiplicative group.\"\"\"\n",
    "n = 15\n",
    "print(f\"The multiplicative group modulo n = {n} is:\")\n",
    "print(multiplicative_group(n))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
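    "One can check that this set of elements indeed forms a group under multiplication modulo $n$ (not ordinary integer multiplication). For example, $2 \\cdot 8 = 16 \\equiv 1 \\pmod{15}$, so $8$ is the inverse of $2$."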
   ]
  },
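  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A brute-force check of the group axioms is a quick way to confirm this. The sketch below (the helper `is_group_mod` is hypothetical) verifies that $1$ is present, that the set is closed under multiplication modulo $n$, and that every element has an inverse; associativity is inherited from integer multiplication."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Sketch: brute-force check that a set is a group under multiplication mod n.\"\"\"\n",
    "import math\n",
    "from typing import Iterable\n",
    "\n",
    "\n",
    "def is_group_mod(elements: Iterable[int], n: int) -> bool:\n",
    "    s = set(elements)\n",
    "    if 1 not in s:\n",
    "        return False  # Missing the identity\n",
    "    # Closure: every product a * b mod n must stay in the set\n",
    "    if any((a * b) % n not in s for a in s for b in s):\n",
    "        return False\n",
    "    # Inverses: every a needs some b with a * b = 1 mod n\n",
    "    return all(any((a * b) % n == 1 for b in s) for a in s)\n",
    "\n",
    "\n",
    "z15 = [x for x in range(1, 15) if math.gcd(x, 15) == 1]\n",
    "print(is_group_mod(z15, 15))  # True\n",
    "print(is_group_mod([1, 2, 3], 15))  # False (2 * 2 = 4 is missing)"
   ]
  },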
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classical order finding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A function for classically computing the order $r$ of an element $x \\in \\mathbb{Z}_n$ is provided below. This function simply computes the sequence \n",
    "\n",
    "\\begin{align}\n",
    "    &x^2 \\text{ mod } n, \\\\\n",
    "    &x^3 \\text{ mod } n, \\\\\n",
    "    &x^4 \\text{ mod } n, \\\\\n",
    "    &\\ \\ \\ \\ \\ \\ \\ \\ \\vdots\n",
    "\\end{align}\n",
    "\n",
    "until an integer $r$ is found such that $x^r \\text{ mod } n = 1$. Since $r \\le |\\mathbb{Z}_n| = \\phi(n)$, this algorithm performs at most $\\phi(n)$ modular multiplications, so its time complexity is $O(\\phi(n))$. This is inefficient: for $n = pq$, $\\phi(n)$ is close to $n$, so the running time is roughly $O(2^L)$, exponential in the number of bits $L$ of $n$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Function for classically computing the order of an element of Z_n.\"\"\"\n",
    "def classical_order_finder(x: int, n: int) -> Optional[int]:\n",
    "    \"\"\"Computes smallest positive r such that x**r mod n == 1.\n",
    "\n",
    "    Args:\n",
    "        x: Integer whose order is to be computed, must be greater than one\n",
    "           and belong to the multiplicative group of integers modulo n (which\n",
    "           consists of positive integers relatively prime to n).\n",
    "        n: Modulus of the multiplicative group.\n",
    "\n",
    "    Returns:\n",
    "        Smallest positive integer r such that x**r == 1 mod n.\n",
    "        Always succeeds (and hence never returns None).\n",
    "\n",
    "    Raises:\n",
    "        ValueError when x is 1 or not an element of the multiplicative\n",
    "        group of integers modulo n.\n",
    "    \"\"\"\n",
    "    # Make sure x is both valid and in Z_n\n",
    "    if x < 2 or x >= n or math.gcd(x, n) > 1:\n",
    "        raise ValueError(f\"Invalid x={x} for modulus n={n}.\")\n",
    "    \n",
    "    # Determine the order\n",
    "    r, y = 1, x\n",
    "    while y != 1:\n",
    "        y = (x * y) % n\n",
    "        r += 1\n",
    "    return r"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An example of computing $r$ for a given $x \\in \\mathbb{Z}_n$ and given $n$ is shown in the code block below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x^r mod n = 8^4 mod 15 = 1\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Example of (classically) computing the order of an element.\"\"\"\n",
    "n = 15  # The multiplicative group is [1, 2, 4, 7, 8, 11, 13, 14]\n",
    "x = 8\n",
    "r = classical_order_finder(x, n)\n",
    "\n",
    "# Check that the order is indeed correct\n",
    "print(f\"x^r mod n = {x}^{r} mod {n} = {x**r % n}\")"
   ]
  },
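  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop in `classical_order_finder` performs $r - 1$ modular multiplications, and by Lagrange's theorem the order $r$ divides $|\\mathbb{Z}_n| = \\phi(n)$. The standalone sketch below (the helper `order_mod` is hypothetical and repeats the same loop without input validation) computes the order of every element of $\\mathbb{Z}_{15}$ and checks that each one divides $\\phi(15) = 8$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Sketch: the order of every element of Z_15 divides phi(15).\"\"\"\n",
    "import math\n",
    "\n",
    "\n",
    "def order_mod(x: int, n: int) -> int:\n",
    "    # Same loop as classical_order_finder, without input validation\n",
    "    r, y = 1, x % n\n",
    "    while y != 1:\n",
    "        y = (x * y) % n\n",
    "        r += 1\n",
    "    return r\n",
    "\n",
    "\n",
    "n = 15\n",
    "group = [x for x in range(1, n) if math.gcd(x, n) == 1]\n",
    "phi = len(group)\n",
    "orders = {x: order_mod(x, n) for x in group}\n",
    "print(orders)  # {1: 1, 2: 4, 4: 2, 7: 4, 8: 4, 11: 2, 13: 4, 14: 2}\n",
    "print(all(phi % r == 0 for r in orders.values()))  # True"
   ]
  },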
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quantum part of Shor's algorithm performs this same order-finding task with a quantum circuit, which we discuss next."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quantum order finding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quantum order finding is essentially quantum phase estimation applied to the unitary $U|y\\rangle = |xy \\text{ mod } n\\rangle$, whose controlled powers compute the modular exponential function $f_x(z)$, for some randomly chosen $x \\in \\mathbb{Z}_n$. Decomposing $U$ into elementary gates can be complex, especially on a first reading. In this tutorial, we instead use arithmetic operations in Cirq, which let us implement such a unitary $U$ without delving into its elementary-gate decomposition.\n",
    "\n",
    "Below we first show an example of a simple arithmetic operation in Cirq (addition), then discuss the operation we care about (modular exponentiation)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Quantum arithmetic operations in Cirq"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here we discuss an example of defining an arithmetic operation in Cirq, namely modular addition. This operation adds the value of the input register into the target register. More specifically, this operation acts on two qubit registers as\n",
    "\n",
    "\\begin{equation}\n",
    "|a\\rangle_i |b\\rangle_t \\mapsto |a\\rangle_i |a + b \\text{ mod } N_t \\rangle_t .\n",
    "\\end{equation}\n",
    "\n",
    "Here, the subscripts $i$ and $t$ denote the <i>i</i>nput and <i>t</i>arget registers, respectively, and $N_t$ is the dimension of the target register.\n",
    "\n",
    "To define this operation, called `Adder`, we inherit from `cirq.ArithmeticOperation`, write a constructor, and implement the three methods `registers`, `with_registers`, and `apply` shown below. The main method is `apply`, which defines the arithmetic. Here, we simply return $a + b$ instead of the more accurate $a + b \\text{ mod } N_t$ above, because `cirq.ArithmeticOperation` wraps any overflow around modulo the size of the target register."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Example of defining an arithmetic (quantum) operation in Cirq.\"\"\"\n",
    "class Adder(cirq.ArithmeticOperation):\n",
    "    \"\"\"Quantum addition.\"\"\"\n",
    "    def __init__(self, target_register, input_register):\n",
    "        self.input_register = input_register\n",
    "        self.target_register = target_register\n",
    "    \n",
    "    def registers(self):\n",
    "        return self.target_register, self.input_register\n",
    "    \n",
    "    def with_registers(self, *new_registers):\n",
    "        return Adder(*new_registers)\n",
    "    \n",
    "    def apply(self, target_value, input_value):\n",
    "        return target_value + input_value"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have the class defined, we can use it in a circuit. The cell below creates two qubit registers, then sets the first register to $|10\\rangle$ and the second register to $|01\\rangle$ (in binary, with the first qubit of each register as the most significant bit) via $X$ gates. Then we apply the `Adder` operation and measure all the qubits.\n",
    "\n",
    "Since $10 + 01 = 11$ (in binary), we expect to measure $|11\\rangle$ in the target register every time. Additionally, since we do not alter the input register, we expect to measure $|10\\rangle$ in the input register every time. In short, the only bitstring we expect to measure is $1011$. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Circuit:\n",
      "\n",
      "0: ───X───#3──────────────────────────────────────────M───\n",
      "          │\n",
      "1: ───────#4──────────────────────────────────────────M───\n",
      "          │\n",
      "2: ───────<__main__.Adder object at 0x7ff2b8c3d9b0>───M───\n",
      "          │\n",
      "3: ───X───#2──────────────────────────────────────────M───\n",
      "\n",
      "\n",
      "Measurement outcomes:\n",
      "\n",
      "   0  1  2  3\n",
      "0  1  0  1  1\n",
      "1  1  0  1  1\n",
      "2  1  0  1  1\n",
      "3  1  0  1  1\n",
      "4  1  0  1  1\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Example of using an Adder in a circuit.\"\"\"\n",
    "# Two qubit registers\n",
    "qreg1 = cirq.LineQubit.range(2)\n",
    "qreg2 = cirq.LineQubit.range(2, 4)\n",
    "\n",
    "# Define the circuit\n",
    "circ = cirq.Circuit(\n",
    "    cirq.ops.X.on(qreg1[0]),\n",
    "    cirq.ops.X.on(qreg2[1]),\n",
    "    Adder(input_register=qreg1, target_register=qreg2),\n",
    "    cirq.measure_each(*qreg1),\n",
    "    cirq.measure_each(*qreg2)\n",
    ")\n",
    "\n",
    "# Display it\n",
    "print(\"Circuit:\\n\")\n",
    "print(circ)\n",
    "\n",
    "# Print the measurement outcomes\n",
    "print(\"\\n\\nMeasurement outcomes:\\n\")\n",
    "print(cirq.sample(circ, repetitions=5).data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the output of this code block, we first see the circuit which shows the initial $X$ gates, the `Adder` operation, then the final measurements. Next, we see the measurement outcomes which are all the bitstring $1011$ as expected."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
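    "It is also possible to view the unitary of the adder operation, which we do below. Here, the target register consists of two qubits, and we specify the input register as the classical integer one (playing the role of the register $|01\\rangle$). The resulting $4 \\times 4$ unitary therefore maps $|b\\rangle_t \\mapsto |b + 1 \\text{ mod } 4\\rangle_t$ for every basis state $|b\\rangle_t$ of the target register."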
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 0, 0, 1],\n",
       "       [1, 0, 0, 0],\n",
       "       [0, 1, 0, 0],\n",
       "       [0, 0, 1, 0]], dtype=int32)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\"\"\"Example of the unitary of an Adder operation.\"\"\"\n",
    "cirq.unitary(\n",
    "    Adder(target_register=cirq.LineQubit.range(2),\n",
    "          input_register=1)\n",
    ").real.astype(np.int32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can understand this unitary as follows. The $i$th column of the unitary is the state $|i + 1 \\text{ mod } 4\\rangle$. For example, the $0$th column is $|0 + 1 \\text{ mod } 4\\rangle = |1\\rangle$ and the $1$st column is $|1 + 1 \\text{ mod } 4\\rangle = |2\\rangle$. Similarly, the $2$nd column is $|3\\rangle$, and the $3$rd column is $|3 + 1 \\text{ mod } 4\\rangle = |0\\rangle$, which is where the wrap-around modulo $N_t = 4$ appears."
   ]
  },
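  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same matrix can also be built classically, which is a handy way to check the column-by-column reading above. The NumPy sketch below (the function `adder_permutation` is our own illustration, not a Cirq API) places a $1$ in row $(b + a) \\text{ mod } 2^k$ of column $b$, for a $k$-qubit target register and a classical addend $a$; for $a = 1$ and $k = 2$ it should print the same matrix as the `cirq.unitary` cell above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Sketch: the Adder unitary as a classical permutation matrix.\"\"\"\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def adder_permutation(addend: int, num_qubits: int) -> np.ndarray:\n",
    "    dim = 2 ** num_qubits\n",
    "    u = np.zeros((dim, dim), dtype=np.int32)\n",
    "    for b in range(dim):\n",
    "        # Column b holds the output basis state |b + addend mod dim>\n",
    "        u[(b + addend) % dim, b] = 1\n",
    "    return u\n",
    "\n",
    "\n",
    "print(adder_permutation(1, 2))"
   ]
  },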
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Modular exponential arithmetic operation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can define the modular exponential arithmetic operation in a similar way to the simple addition arithmetic operation, shown below. For the purposes of understanding Shor's algorithm, the most important part of the following code block is the `apply` method which defines the arithmetic operation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "\"\"\"Defines the modular exponential operation used in Shor's algorithm.\"\"\"\n",
    "class ModularExp(cirq.ArithmeticOperation):\n",
    "    \"\"\"Quantum modular exponentiation.\n",
    "\n",
    "    This class represents the unitary which multiplies base raised to exponent\n",
    "    into the target modulo the given modulus. More precisely, it represents the\n",
    "    unitary V which computes modular exponentiation x**e mod n:\n",
    "\n",
    "        V|y⟩|e⟩ = |y * x**e mod n⟩ |e⟩     0 <= y < n\n",
    "        V|y⟩|e⟩ = |y⟩ |e⟩                  n <= y\n",
    "\n",
    "    where y is the target register, e is the exponent register, x is the base\n",
    "    and n is the modulus. Consequently,\n",
    "\n",
    "        V|y⟩|e⟩ = (U**e|y⟩)|e⟩\n",
    "\n",
    "    where U is the unitary defined as\n",
    "\n",
    "        U|y⟩ = |y * x mod n⟩      0 <= y < n\n",
    "        U|y⟩ = |y⟩                n <= y\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self, \n",
    "        target: Sequence[cirq.Qid],\n",
    "        exponent: Union[int, Sequence[cirq.Qid]], \n",
    "        base: int,\n",
    "        modulus: int\n",
    "    ) -> None:\n",
    "        if len(target) < modulus.bit_length():\n",
    "            raise ValueError(f'Register with {len(target)} qubits is too small '\n",
    "                             f'for modulus {modulus}')\n",
    "        self.target = target\n",
    "        self.exponent = exponent\n",
    "        self.base = base\n",
    "        self.modulus = modulus\n",
    "\n",
    "    def registers(self) -> Sequence[Union[int, Sequence[cirq.Qid]]]:\n",
    "        return self.target, self.exponent, self.base, self.modulus\n",
    "\n",
    "    def with_registers(\n",
    "            self,\n",
    "            *new_registers: Union[int, Sequence['cirq.Qid']],\n",
    "    ) -> cirq.ArithmeticOperation:\n",
    "        if len(new_registers) != 4:\n",
    "            raise ValueError(f'Expected 4 registers (target, exponent, base, '\n",
    "                             f'modulus), but got {len(new_registers)}')\n",
    "        target, exponent, base, modulus = new_registers\n",
    "        if not isinstance(target, Sequence):\n",
    "            raise ValueError(\n",
    "                f'Target must be a qubit register, got {type(target)}')\n",
    "        if not isinstance(base, int):\n",
    "            raise ValueError(\n",
    "                f'Base must be a classical constant, got {type(base)}')\n",
    "        if not isinstance(modulus, int):\n",
    "            raise ValueError(\n",
    "                f'Modulus must be a classical constant, got {type(modulus)}')\n",
    "        return ModularExp(target, exponent, base, modulus)\n",
    "\n",
    "    def apply(self, *register_values: int) -> int:\n",
    "        assert len(register_values) == 4\n",
    "        target, exponent, base, modulus = register_values\n",
    "        if target >= modulus:\n",
    "            return target\n",
    "        return (target * base**exponent) % modulus\n",
    "\n",
    "    def _circuit_diagram_info_(\n",
    "            self,\n",
    "            args: cirq.CircuitDiagramInfoArgs,\n",
    "    ) -> cirq.CircuitDiagramInfo:\n",
    "        assert args.known_qubits is not None\n",
    "        wire_symbols: List[str] = []\n",
    "        t, e = 0, 0\n",
    "        for qubit in args.known_qubits:\n",
    "            if qubit in self.target:\n",
    "                if t == 0:\n",
    "                    if isinstance(self.exponent, Sequence):\n",
    "                        e_str = 'e'\n",
    "                    else:\n",
    "                        e_str = str(self.exponent)\n",
    "                    wire_symbols.append(\n",
    "                        f'ModularExp(t*{self.base}**{e_str} % {self.modulus})')\n",
    "                else:\n",
    "                    wire_symbols.append('t' + str(t))\n",
    "                t += 1\n",
    "            if isinstance(self.exponent, Sequence) and qubit in self.exponent:\n",
    "                wire_symbols.append('e' + str(e))\n",
    "                e += 1\n",
    "        return cirq.CircuitDiagramInfo(wire_symbols=tuple(wire_symbols))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the `apply` method, we see that we evaluate `(target * base**exponent) % modulus`. The `target` and the `exponent` depend on the values of the respective qubit registers, and the `base` and `modulus` are constant -- namely, the `modulus` is $n$ and the `base` is some $x \\in \\mathbb{Z}_n$. "
   ]
  },
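  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `apply` is ordinary integer arithmetic, we can evaluate it classically to see what the operation does to basis states. The sketch below copies the arithmetic of `ModularExp.apply` into a plain function (`modexp_apply` and `is_permutation` are hypothetical helper names) and checks that it permutes the target basis states when the base is in $\\mathbb{Z}_n$, which is what makes the operation unitary. With a base such as $5$, which shares a factor with $n = 15$, it is not a permutation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Sketch: ModularExp.apply evaluated classically on basis states.\"\"\"\n",
    "def modexp_apply(target: int, exponent: int, base: int, modulus: int) -> int:\n",
    "    # Same arithmetic as ModularExp.apply\n",
    "    if target >= modulus:\n",
    "        return target\n",
    "    return (target * base**exponent) % modulus\n",
    "\n",
    "\n",
    "def is_permutation(exponent: int, base: int, modulus: int, num_bits: int) -> bool:\n",
    "    # True if apply maps the 2**num_bits target states one-to-one onto themselves\n",
    "    states = range(2 ** num_bits)\n",
    "    outputs = [modexp_apply(y, exponent, base, modulus) for y in states]\n",
    "    return sorted(outputs) == list(states)\n",
    "\n",
    "\n",
    "# Starting from target |1>, the exponent e produces 7**e mod 15\n",
    "print([modexp_apply(1, e, 7, 15) for e in range(6)])  # [1, 7, 4, 13, 1, 7]\n",
    "print(is_permutation(3, 7, 15, 4))  # True: 7 is in Z_15\n",
    "print(is_permutation(1, 5, 15, 4))  # False: gcd(5, 15) = 5"
   ]
  },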
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The total number of qubits we will use is $3 (L + 1)$, where $L$ is the number of bits needed to store the integer $n$ to factor. The unitary which implements the modular exponential therefore has $(2^{3(L + 1)})^2 = 4^{3(L + 1)}$ entries. For a modest $n = 15$ ($L = 4$), storing the unitary requires $4^{15} = 2^{30}$ complex numbers (16 GiB at double precision), which is out of reach of most current standard laptops."
   ]
  },
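  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The arithmetic in the previous paragraph can be reproduced in a few lines of Python. The sketch below (the helper `dense_unitary_cost` is hypothetical) uses this notebook's register sizes, $L$ target qubits plus $2L + 3$ exponent qubits, and assumes 16 bytes per complex entry (NumPy's `complex128`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Sketch: qubit count and dense-unitary memory for a given n.\"\"\"\n",
    "from typing import Tuple\n",
    "\n",
    "\n",
    "def dense_unitary_cost(n: int) -> Tuple[int, int, int]:\n",
    "    L = n.bit_length()\n",
    "    num_qubits = L + (2 * L + 3)  # target + exponent registers\n",
    "    num_entries = (2 ** num_qubits) ** 2  # = 4**num_qubits\n",
    "    num_bytes = 16 * num_entries  # complex128 uses 16 bytes per entry\n",
    "    return num_qubits, num_entries, num_bytes\n",
    "\n",
    "\n",
    "qubits, entries, num_bytes = dense_unitary_cost(15)\n",
    "print(qubits, entries, num_bytes / 2**30)  # 15 1073741824 16.0"
   ]
  },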
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "To factor n = 15 which has L = 4 bits, we need 3L + 3 = 15 qubits.\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Create the target and exponent registers for phase estimation,\n",
    "and see the number of qubits needed for Shor's algorithm.\n",
    "\"\"\"\n",
    "n = 15\n",
    "L = n.bit_length()\n",
    "\n",
    "# The target register has L qubits\n",
    "target = cirq.LineQubit.range(L)\n",
    "\n",
    "# The exponent register has 2L + 3 qubits\n",
    "exponent = cirq.LineQubit.range(L, 3 * L + 3)\n",
    "\n",
    "# Display the total number of qubits to factor this n\n",
    "print(f\"To factor n = {n} which has L = {L} bits, we need 3L + 3 = {3 * L + 3} qubits.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As with the simple adder operation, this modular exponential operation has a unitary which we can display (memory permitting) as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"See (part of) the unitary for a modular exponential operation.\"\"\"\n",
    "# Pick some element of the multiplicative group modulo n\n",
    "x = 7\n",
    "\n",
    "# Display (part of) the unitary. Uncomment if n is small enough\n",
    "# cirq.unitary(ModularExp(target, exponent, x, n))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using the modular exponential operation in a circuit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The quantum part of Shor's algorithm is phase estimation with the unitary $U|y\\rangle = |xy \\text{ mod } n\\rangle$; the `ModularExp` operation applies the controlled powers $U^e$ needed by phase estimation as a single operation. The following cell defines a function which creates the circuit for Shor's algorithm using the `ModularExp` operation we defined above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Function to make the quantum circuit for order finding.\"\"\"\n",
    "def make_order_finding_circuit(x: int, n: int) -> cirq.Circuit:\n",
    "    \"\"\"Returns quantum circuit which computes the order of x modulo n.\n",
    "\n",
    "    The circuit uses Quantum Phase Estimation to compute an eigenvalue of\n",
    "    the unitary\n",
    "\n",
    "        U|y⟩ = |y * x mod n⟩      0 <= y < n\n",
    "        U|y⟩ = |y⟩                n <= y\n",
    "\n",
    "    Args:\n",
    "        x: positive integer whose order modulo n is to be found\n",
    "        n: modulus relative to which the order of x is to be found\n",
    "\n",
    "    Returns:\n",
    "        Quantum circuit for finding the order of x modulo n\n",
    "    \"\"\"\n",
    "    L = n.bit_length()\n",
    "    target = cirq.LineQubit.range(L)\n",
    "    exponent = cirq.LineQubit.range(L, 3 * L + 3)\n",
    "    return cirq.Circuit(\n",
    "        cirq.X(target[L - 1]),\n",
    "        cirq.H.on_each(*exponent),\n",
    "        ModularExp(target, exponent, x, n),\n",
    "        cirq.QFT(*exponent, inverse=True),\n",
    "        cirq.measure(*exponent, key='exponent'),\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using this function, we can visualize the circuit for a given $x$ and $n$ as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0: ────────ModularExp(t*7**e % 15)────────────────────────────\n",
      "           │\n",
      "1: ────────t1─────────────────────────────────────────────────\n",
      "           │\n",
      "2: ────────t2─────────────────────────────────────────────────\n",
      "           │\n",
      "3: ────X───t3─────────────────────────────────────────────────\n",
      "           │\n",
      "4: ────H───e0────────────────────────QFT^-1───M('exponent')───\n",
      "           │                         │        │\n",
      "5: ────H───e1────────────────────────#2───────M───────────────\n",
      "           │                         │        │\n",
      "6: ────H───e2────────────────────────#3───────M───────────────\n",
      "           │                         │        │\n",
      "7: ────H───e3────────────────────────#4───────M───────────────\n",
      "           │                         │        │\n",
      "8: ────H───e4────────────────────────#5───────M───────────────\n",
      "           │                         │        │\n",
      "9: ────H───e5────────────────────────#6───────M───────────────\n",
      "           │                         │        │\n",
      "10: ───H───e6────────────────────────#7───────M───────────────\n",
      "           │                         │        │\n",
      "11: ───H───e7────────────────────────#8───────M───────────────\n",
      "           │                         │        │\n",
      "12: ───H───e8────────────────────────#9───────M───────────────\n",
      "           │                         │        │\n",
      "13: ───H───e9────────────────────────#10──────M───────────────\n",
      "           │                         │        │\n",
      "14: ───H───e10───────────────────────#11──────M───────────────\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Example of the quantum circuit for period finding.\"\"\"\n",
    "n = 15\n",
    "x = 7\n",
    "circuit = make_order_finding_circuit(x, n)\n",
    "print(circuit)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
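    "We put the exponent register into an equal superposition via Hadamard gates. The $X$ gate on the last qubit of the target register prepares the target in the state $|1\\rangle$ (the last qubit is the least significant bit). This state is a uniform superposition of eigenstates of $U$ with eigenvalues $e^{2\\pi i s / r}$, $s = 0, \\dots, r - 1$, so phase estimation yields an estimate of $s / r$ for a uniformly random $s$. The modular exponential operation performs the sequence of controlled unitaries in phase estimation, then we apply the inverse quantum Fourier transform to the exponent register and measure to read out the result."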
   ]
  },
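  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before sampling, it can help to know which outcomes to expect. The NumPy sketch below (the function `ideal_exponent_distribution` is hypothetical, not a Cirq API) computes the noiseless outcome probabilities of this circuit classically: after `ModularExp`, the target holds $x^e \\text{ mod } n$, so grouping exponents by target value and Fourier transforming each group gives the distribution of the measured integer. It assumes the standard big-endian integer convention for the exponent register. For $x = 7$ and $n = 15$ the order is $r = 4$, which divides $2^{11}$, so only multiples of $2^{11} / 4 = 512$ occur."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Sketch: ideal (noiseless) outcome distribution of the order finding circuit.\"\"\"\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def ideal_exponent_distribution(x: int, n: int, num_exponent_qubits: int) -> np.ndarray:\n",
    "    m = 2 ** num_exponent_qubits\n",
    "    # Target value after ModularExp when the target starts in |1>\n",
    "    powers = np.array([pow(x, e, n) for e in range(m)])\n",
    "    probs = np.zeros(m)\n",
    "    for v in np.unique(powers):\n",
    "        # Exponent-register amplitudes entangled with target value v\n",
    "        amps = (powers == v).astype(float) / np.sqrt(m)\n",
    "        # Fourier transform; the amplitudes are real, so |amp|^2 is the same\n",
    "        # for the forward and inverse sign conventions\n",
    "        probs += np.abs(np.fft.fft(amps) / np.sqrt(m)) ** 2\n",
    "    return probs\n",
    "\n",
    "\n",
    "probs = ideal_exponent_distribution(x=7, n=15, num_exponent_qubits=11)\n",
    "print(np.nonzero(probs > 1e-9)[0].tolist())  # [0, 512, 1024, 1536]\n",
    "print(probs[512])  # approximately 0.25"
   ]
  },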
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the measurement results, we can sample from a smaller circuit. (Note that in practice we would never run Shor's algorithm with $n = 6$ because it is even. This is just an example to illustrate the measurement outcomes.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Raw measurements:\n",
      "exponent=01010001, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000, 00000000\n",
      "\n",
      "Integer in exponent register:\n",
      "   exponent\n",
      "0         0\n",
      "1       256\n",
      "2         0\n",
      "3       256\n",
      "4         0\n",
      "5         0\n",
      "6         0\n",
      "7       256\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Measuring Shor's period finding circuit.\"\"\"\n",
    "circuit = make_order_finding_circuit(x=5, n=6)\n",
    "res = cirq.sample(circuit, repetitions=8)\n",
    "\n",
    "print(\"Raw measurements:\")\n",
    "print(res)\n",
    "\n",
    "print(\"\\nInteger in exponent register:\")\n",
    "print(res.data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We interpret each measured bitstring as an integer, but what do these integers tell us? In the next section we look at how to classically post-process to interpret them."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classical post-processing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
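    "Let $m$ be the number of qubits in the exponent register. Dividing the measured integer by $2^m$ gives a number close to $s / r$, where $r$ is the order of $x \\in \\mathbb{Z}_n$ and $0 \\le s < r$ is an integer. We use the continued fractions algorithm to recover a fraction $s / r$ with denominator at most $n$, then return its denominator if it passes the check $x^r \\text{ mod } n = 1$; otherwise we return `None`. (If $s$ and $r$ share a common factor, the fraction reduces to a smaller denominator and the check fails.)"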
   ]
  },
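  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of this post-processing, take $x = 7$ and $n = 15$, so the exponent register has $m = 2L + 3 = 11$ qubits and the order is $r = 4$; ideal outcomes are multiples of $2^{11} / 4 = 512$. The sketch below (the function name `order_from_measurement` is hypothetical) uses `fractions.Fraction(k, 2**m).limit_denominator(n)`, an exact-rational variant of the `from_float` call used in the next cell. The outcome $k = 1024$ gives $s / r = 2 / 4 = 1 / 2$, whose denominator fails the check, and $k = 0$ carries no information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Sketch: recovering r from a measured integer k via continued fractions.\"\"\"\n",
    "import fractions\n",
    "from typing import Optional\n",
    "\n",
    "\n",
    "def order_from_measurement(k: int, m: int, x: int, n: int) -> Optional[int]:\n",
    "    # Closest fraction to k / 2**m whose denominator is at most n\n",
    "    f = fractions.Fraction(k, 2**m).limit_denominator(n)\n",
    "    if f.numerator == 0:\n",
    "        return None  # s = 0 says nothing about r\n",
    "    r = f.denominator\n",
    "    return r if pow(x, r, n) == 1 else None\n",
    "\n",
    "\n",
    "for k in [0, 512, 1024, 1536, 1535]:\n",
    "    print(k, order_from_measurement(k, m=11, x=7, n=15))\n",
    "# Expected: 0 None, 512 4, 1024 None, 1536 4, 1535 4 (1535/2048 is closest to 3/4)"
   ]
  },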
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_measurement(result: cirq.TrialResult, x: int, n: int) -> Optional[int]:\n",
    "    \"\"\"Interprets the output of the order finding circuit.\n",
    "\n",
    "    Specifically, it determines s/r such that exp(2πis/r) is an eigenvalue\n",
    "    of the unitary\n",
    "\n",
    "        U|y⟩ = |xy mod n⟩  0 <= y < n\n",
    "        U|y⟩ = |y⟩         n <= y\n",
    "    \n",
    "    then computes r (by continued fractions) if possible, and returns it.\n",
    "\n",
    "    Args:\n",
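    "        result: trial result obtained by sampling the output of the\n",
    "            circuit built by make_order_finding_circuit\n",
    "        x: integer whose order modulo n is to be found\n",
    "        n: modulus relative to which the order of x is to be found\n",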
    "\n",
    "    Returns:\n",
    "        r, the order of x modulo n or None.\n",
    "    \"\"\"\n",
    "    # Read the output integer of the exponent register\n",
    "    exponent_as_integer = result.data[\"exponent\"][0]\n",
    "    exponent_num_bits = result.measurements[\"exponent\"].shape[1]\n",
    "    eigenphase = float(exponent_as_integer / 2**exponent_num_bits)\n",
    "\n",
    "    # Run the continued fractions algorithm to determine f = s / r\n",
    "    f = fractions.Fraction.from_float(eigenphase).limit_denominator(n)\n",
    "    \n",
    "    # If the numerator is zero, the order finder failed\n",
    "    if f.numerator == 0:\n",
    "        return None\n",
    "    \n",
    "    # Else, return the denominator if it is valid\n",
    "    r = f.denominator\n",
    "    if x**r % n != 1:\n",
    "        return None\n",
    "    return r"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next code block shows an example of creating an order finding circuit, executing it, then using the classical post-processing function to determine the order. Recall that the quantum part of the algorithm succeeds with some probability. If the order is `None`, try re-running the cell a few times."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finding the order of x = 5 modulo n = 6\n",
      "\n",
      "Raw measurements:\n",
      "exponent=1, 0, 0, 0, 0, 0, 0, 0, 0\n",
      "\n",
      "Integer in exponent register:\n",
      "   exponent\n",
      "0       256\n",
      "\n",
      "Order r = 2\n",
      "x^r mod n = 5^2 mod 6 = 1\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Example of the classical post-processing.\"\"\"\n",
    "# Set n and x here\n",
    "n = 6\n",
    "x = 5\n",
    "\n",
    "print(f\"Finding the order of x = {x} modulo n = {n}\\n\")\n",
    "circuit = make_order_finding_circuit(x, n)\n",
    "measurement = cirq.sample(circuit, repetitions=1)\n",
    "print(\"Raw measurements:\")\n",
    "print(measurement)\n",
    "\n",
    "print(\"\\nInteger in exponent register:\")\n",
    "print(measurement.data)\n",
    "\n",
    "r = process_measurement(measurement, x, n)\n",
    "print(\"\\nOrder r =\", r)\n",
    "if r is not None:\n",
    "    print(f\"x^r mod n = {x}^{r} mod {n} = {x**r % n}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You should see that the order of $x = 5$ in $\\mathbb{Z}_6$ is $r = 2$. Indeed, $5^2 \\text{ mod } 6 = 25 \\text{ mod } 6 = 1$. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Quantum order finder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now define a streamlined function for the quantum version of order finding using the functions we have previously written. The quantum order finder below creates the circuit, executes it, and processes the measurement result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def quantum_order_finder(x: int, n: int) -> Optional[int]:\n",
    "    \"\"\"Computes smallest positive r such that x**r mod n == 1.\n",
    "    \n",
    "    Args:\n",
    "        x: integer whose order is to be computed, must be greater than one\n",
    "           and belong to the multiplicative group of integers modulo n (which\n",
    "           consists of positive integers relatively prime to n),\n",
    "        n: modulus of the multiplicative group.\n",
    "    \"\"\"\n",
    "    # Check that the integer x is a valid element of the multiplicative group\n",
    "    # modulo n\n",
    "    if x < 2 or n <= x or math.gcd(x, n) > 1:\n",
    "        raise ValueError(f'Invalid x={x} for modulus n={n}.')\n",
    "\n",
    "    # Create the order finding circuit\n",
    "    circuit = make_order_finding_circuit(x, n)\n",
    "    \n",
    "    # Sample from the order finding circuit\n",
    "    measurement = cirq.sample(circuit)\n",
    "    \n",
    "    # Return the processed measurement result\n",
    "    return process_measurement(measurement, x, n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This completes our quantum implementation of an order finder, and the quantum part of Shor's algorithm."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The complete factoring algorithm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use this quantum order finder (or the classical order finder) to complete Shor's algorithm. In the following code block, we add a few pre-processing steps which:\n",
    "\n",
    "(1) Check if $n$ is even,\n",
    "\n",
    "(2) Check if $n$ is prime,\n",
    "\n",
    "(3) Check if $n$ is a prime power,\n",
    "\n",
    "all of which can be done efficiently with a classical computer. Additionally, we add the last necessary post-processing step which uses the order $r$ to compute a non-trivial factor $p$ of $n$. This is achieved by computing $y = x^{r / 2} \\text{ mod } n$ (assuming $r$ is even), then computing $p = \\text{gcd}(y - 1, n)$."
   ]
  },
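  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small hand-worked illustration of this step (the values $n = 15$ and $x = 7$ are chosen here purely for exposition), the successive powers of $7$ modulo $15$ are $7, 4, 13, 1$, so the order of $7$ is $r = 4$. The post-processing step then gives\n",
    "\n",
    "\\begin{align*}\n",
    "y &= 7^{4/2} \\text{ mod } 15 = 49 \\text{ mod } 15 = 4, \\\\\n",
    "p &= \\text{gcd}(y - 1, 15) = \\text{gcd}(3, 15) = 3,\n",
    "\\end{align*}\n",
    "\n",
    "and indeed $15 = 3 \\cdot 5$, with the cofactor $5 = \\text{gcd}(y + 1, 15)$."
   ]
  },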
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Functions for factoring from start to finish.\"\"\"\n",
    "def find_factor_of_prime_power(n: int) -> Optional[int]:\n",
    "    \"\"\"Returns non-trivial factor of n if n is a prime power, else None.\"\"\"\n",
    "    for k in range(2, math.floor(math.log2(n)) + 1):\n",
    "        c = math.pow(n, 1 / k)\n",
    "        c1 = math.floor(c)\n",
    "        if c1**k == n:\n",
    "            return c1\n",
    "        c2 = math.ceil(c)\n",
    "        if c2**k == n:\n",
    "            return c2\n",
    "    return None\n",
    "\n",
    "\n",
    "def find_factor(\n",
    "    n: int,\n",
    "    order_finder: Callable[[int, int], Optional[int]] = quantum_order_finder,\n",
    "    max_attempts: int = 30\n",
    ") -> Optional[int]:\n",
    "    \"\"\"Returns a non-trivial factor of composite integer n.\n",
    "\n",
    "    Args:\n",
    "        n: Integer to factor.\n",
    "        order_finder: Function for finding the order of elements of the\n",
    "            multiplicative group of integers modulo n.\n",
    "        max_attempts: number of random x's to try, also an upper limit\n",
    "            on the number of order_finder invocations.\n",
    "\n",
    "    Returns:\n",
    "        Non-trivial factor of n or None if no such factor was found.\n",
    "        Factor k of n is trivial if it is 1 or n.\n",
    "    \"\"\"\n",
    "    # If the number is prime, there are no non-trivial factors\n",
    "    if sympy.isprime(n):\n",
    "        print(\"n is prime!\")\n",
    "        return None\n",
    "    \n",
    "    # If the number is even, two is a non-trivial factor\n",
    "    if n % 2 == 0:\n",
    "        return 2\n",
    "    \n",
    "    # If n is a prime power, we can find a non-trivial factor efficiently\n",
    "    c = find_factor_of_prime_power(n)\n",
    "    if c is not None:\n",
    "        return c\n",
    "    \n",
    "    for _ in range(max_attempts):\n",
    "        # Choose a random number between 2 and n - 1\n",
    "        x = random.randint(2, n - 1)\n",
    "        \n",
    "        # Most likely x and n will be relatively prime\n",
    "        c = math.gcd(x, n)\n",
    "        \n",
    "        # If x and n are not relatively prime, we got lucky and found\n",
    "        # a non-trivial factor\n",
    "        if 1 < c < n:\n",
    "            return c\n",
    "        \n",
    "        # Compute the order r of x modulo n using the order finder\n",
    "        r = order_finder(x, n)\n",
    "        \n",
    "        # If the order finder failed, try again\n",
    "        if r is None:\n",
    "            continue\n",
    "        \n",
    "        # If the order r is even, try again\n",
    "        if r % 2 != 0:\n",
    "            continue\n",
    "        \n",
    "        # Compute the non-trivial factor\n",
    "        y = x**(r // 2) % n\n",
    "        assert 1 < y < n\n",
    "        c = math.gcd(y - 1, n)\n",
    "        if 1 < c < n:\n",
    "            return c\n",
    "\n",
    "    print(f\"Failed to find a non-trivial factor in {max_attempts} attempts.\")\n",
    "    return None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function `find_factor` uses the `quantum_order_finder` by default, in which case it is executing Shor's algorithm. As previously mentioned, due to the large memory requirements for classically simulating this circuit, we cannot run Shor's algorithm for $n \\ge 15$. However, we can use the classical order finder as a substitute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Factoring n = pq = 184573\n",
      "p = 487\n",
      "q = 379\n"
     ]
    }
   ],
   "source": [
    "\"\"\"Example of factoring via Shor's algorithm (order finding).\"\"\"\n",
    "# Number to factor\n",
    "n = 184573\n",
    "\n",
    "# Attempt to find a factor\n",
    "p = find_factor(n, order_finder=classical_order_finder)\n",
    "q = n // p\n",
    "\n",
    "print(\"Factoring n = pq =\", n)\n",
    "print(\"p =\", p)\n",
    "print(\"q =\", q)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
